华为难题揭榜 · 块 KV 复用的交叉注意力修复

块级 KV 复用的交叉注意力修复方案

作者:吴国立 导师:文泽忆 机构:香港科技大学(广州)

一、问题描述

RAG 文档重排、对话中段匹配等场景使 Prefix Cache 失效;业界改为块级独立 prefill 后拼接 KV 以提升复用率。 拼接后的注意力条件分布与 full-prefill oracle 不一致,导致精度与 TTFT 难以同时达标。

\(E_1\) · RoPE 错位

局部位置与全局位置不一致,\(QK\) 相位错配。

\(E_2\) · 跨块断裂

等价于局部 KV 与全局 KV 的表示差异。

\(E_3\) · 假 Sink

块首汇聚峰拼接后扰动全局注意力 mass。

出题指标:KV 复用率 \(\ge 80\%\),TTFT 降低 \(\ge 70\%\),精度损失 \(< 1\%\); 覆盖 Qwen3-32B、DeepSeek-V3.2、GLM-5 及 RULER、LongBench;在 UCM 框架验证。

二、算法思想概括

块级 KV 复用省下了重复计算,却破坏了「整段上下文下的条件分布」。 本方案在统一全局位置编码之后,用数据结构修复算子纠偏两条线协同, 把混合缓存中的局部表示逐步拉向 full-prefill 的全局 oracle。

  • 治理 \(E_1\):对加载 KV 做全局 RoPE 归位,消除几何错位。
  • 治理 \(E_2\):按预算 \(\tau\) 确定重算集合 \(\mathcal{T}_\tau\),在 \(\mathcal{T}_\tau\) 上调用 \(F\) 做因果前向并写回 KV,这一步将部分局部 KV 物理修正为全局 KV。
  • 治理 \(E_3\):Prefill 剩余段与 Decode 全程调用同一 \(F(Q,L)\),以有偏差意识的注意力分配抑制假 Sink、引导 mass 向全局 KV 倾斜。
  • 同质异量:Prefill 对 token 集合并行调用 \(F\);Decode 对单个新 token 调用 \(F\)——算子相同,仅规模不同。

三、核心洞察

洞察 1 · \(E_2\) 的等价表述: 「KV 的局部属性 vs 全局属性」与 \(E_2\) 同构。 离线块 prefill 得 \((\tilde K,\tilde V)\)(局部条件);在线 full-prefill 需 \((K^{\star},V^{\star})\)(全局条件)。 二者之差即 \(E_2\),也是混合缓存 \(L\) 中“全局KV / 局部KV”标签并存的根源。
洞察 2 · \(\tau\) 修复策略(数据结构侧): 在接缝及末块尾部构造 \(\mathcal{T}_\tau\),调用 \(F\) 在 \(\mathcal{T}_\tau\) 上因果前向并写回 KV,物理修正 \(E_2\)。 当 \(\tau = |B|/2\) 且 B₁ 真前缀时做非对称预算分配,可覆盖 \(B_2\cup\cdots\cup B_m\),结构上逼近无损。
洞察 3 · \(F(Q,L)\)(算法侧): \(E_3\) 来自各层 KV 在“局部KV / 全局KV”标签下的注意力分布差异。 \(F\) 使 query 对局部 KV 的注意力向全局 KV 倾斜,并抑制块首假 Sink。
误差本质对策机制
\(E_1\)几何不一致RoPE 归位解析旋转
\(E_2\)局部 KV \(\neq\) 全局 KV\(\tau\) + \(F\) 写回数据结构物理矫正
\(E_3\)mass 畸变\(F(Q,L)\)注意力分布纠偏

协同:\(\tau\) 改 KV 本体,\(F\) 改 attention 映射。 当 \(\tau \ge |B|/2\) 且非对称覆盖时,\(L \approx L^{\star}\),\(F\) 内各抑制/倾斜项为空,自动退化为标准自注意力,详见下文分析。

四、方法论(哲学概括)

本质:省算破坏整体性;修复是以最小重算代价重建上下文关系。

两点论:兼顾复用、时延、精度;抓住边界、层敏感性与假 Sink。

重点论:在有限 \(\tau\) 下结构化修复,而非全量重算。

量变→质变:\(\tau\) 增大使物理修复覆盖后缀;\(F\) 的纠偏项随 \(L\to L^{\star}\) 自动消失。

五、算法流程示意图

块级 KV 复用流程:B1 全全局 KV,B2…Bm 局部与全局混合,τ 接缝窗口,F(Q,L) 统一算子,Decode 阶段
全局 KV(已修复) 局部 KV(待纠偏) 虚线 · τ 接缝窗口

圆点表示各层 KV 状态;边界 τ 窗口经 \(F\) 物理重算后由黑变蓝;Decode 阶段继续调用同一 \(F\)。

六、符号描述

5.1 序列与分块

\(S = B_1 \oplus B_2 \oplus \cdots \oplus B_m\),其中第 \(i\) 块为 token 区间

\[ B_i = x_{s_i : e_i} \]

\(s_i\):块 \(B_i\) 在整段序列中的起始全局下标(含);
\(e_i\)结束下标(不含,半开区间,与 Python 切片一致)。
例:\(s_1=0,\, e_1=384\) 表示 \(B_1\) 占据位置 \(0,1,\ldots,383\)。

5.2 混合缓存 \(L\)

对每个全局位置 \(t\)、层 \(\ell\):

\[ L[t,\ell] = \bigl(\hat K_{t,\ell},\; \hat V_{t,\ell},\; m_t\bigr) \]

  • \(\hat K_{t,\ell}, \hat V_{t,\ell}\):拼接加载后的 KV;
  • \(m_t = (\type_t,\, b_t,\, g_t)\):
    • \(\type_t \in \{\mathrm{global},\mathrm{local}\}\) — 该位置 KV 是否已等价于 full-prefill oracle;
    • \(b_t\) — 所属块编号;
    • \(g_t\) — 全局位置(RoPE 与 \(F\) 使用)。

紧凑写法(说明型,非公式定义):

\[ L \;\approx\; [\;\underbrace{K^{\mathrm{global}}V^{\mathrm{global}}}_{B_1\;\text{全局KV}},\;\underbrace{K^{\mathrm{local}}V^{\mathrm{local}}}_{B_2\;\text{加载时为局部KV}},\;\ldots\;] \]

上标 global / local 表示 KV 的质量类型, 下标表示块编号——并非每个块固定为某一类: B₁ 在真前缀条件下通常全为 global;B₂ 加载时多为 local,经 \(\mathcal{T}_\tau\) 上 \(F\) 写回后可变为 global。

5.3 重算集合 \(\mathcal{T}_\tau\)

接缝 \(B_i \mid B_{i+1}\) 的对称窗口:

\[ \mathcal{T}^{(i)}_{\mathrm{seam}} = [\,s_i,\, e_i)\cap[e_i-\tau,\,e_i) \;\cup\; [\,s_{i+1},\, e_{i+1})\cap[s_{i+1},\, s_{i+1}+\tau) \]

非对称规则(B₁ 为真前缀):接缝 \(B_1|B_2\) 不重算 \(B_1\) 尾部 \(\tau\) token,预算补至 \(B_m\) 尾部 \(\tau\) token。

重算占比 \(|\mathcal{T}_\tau|/N\);工程目标:\(\le 5\%\) 且精度达标。

5.4 算子 \(F(Q,L)\) 的功能性表述

\(F\) 是在标准自注意力之上的结构感知扩展,输入 query 集合 \(Q\) 与混合缓存 \(L\),输出与 oracle 一致的隐藏表示:

\[ F(Q,\,L) \;=\; \Attn\!\bigl(Q,\, \hat K,\, \hat V;\, \mathcal{M}_{\mathrm{corr}}(L)\bigr) \;\approx\; \Attn(Q,\, K^{\star},\, V^{\star}) \]

其中 \(\mathcal{M}_{\mathrm{corr}}(L)\) 为由 \(L\) 元数据导出的纠偏项,明确为:

\[ \mathcal{M}_{\mathrm{corr}}(L) = \{M_{\mathrm{local}}(j),\; M_{\mathrm{sink}}(j)\} \]

\[ M_{\mathrm{local}}(j)=\mathbb{1}[\type_j=\mathrm{local}], \qquad M_{\mathrm{sink}}(j)=\mathbb{1}[j\in\mathcal{S}'] \]

  • Prefill · 物理修复模式:对 \(t\in\mathcal{T}_\tau\),\(F\) 执行完整层间因果前向,写回 \((\hat K_{t,\ell},\hat V_{t,\ell})\) 并将 \(\type_t\leftarrow\mathrm{global}\);
  • Prefill / Decode · 推理模式:对 query 仅做 attention 前向(不写回),用 \(\mathcal{M}_{\mathrm{corr}}(L)\) 纠偏 \(E_3\);
  • 退化:当 \(\tau \ge |B|/2\) 且非对称覆盖 \(B_{2:m}\) 时,\(L\approx L^{\star}\),\(\mathcal{M}_{\mathrm{corr}}(L)=\varnothing\),\(F\) 等价于标准 \(\mathrm{Softmax}(QK^{\top}/\sqrt{d})V\)。

\(L^{\star}\):对完整 \(S\) 一次 full-prefill 的 oracle 缓存(对照基准,非在线必算)。

七、具体方案

步骤操作输出误差
块匹配,加载 KV,构造 \(L^{(0)}\) 及元数据 \(m_t\) 混合缓存
全局 RoPE 归位:\(\hat K\leftarrow R(g_t)R^{-1}(t_{\mathrm{loc}})\hat K\) \(L^{(1)}\) \(E_1\)
在 \(\mathcal{T}_\tau\) 上调用 \(F\) 做因果前向并写回 KV,\(\type_t\leftarrow\mathrm{global}\) \(L^{(2)}\) \(E_2\)
Prefill 剩余段 + Decode:调用 \(F(Q,L^{(2)})\) 推理(单 token 或批量 query) 隐藏态 / 生成 \(E_3\)

6.1 \(F(Q,L)\) 设计

\(F\) 在标准 logits 上叠加可开关的纠偏项,核心理念: 局部 KV 降权 / 全局 KV 相对倾斜,块首假 Sink 抑制;不含额外边界结构项——跨块语义由步骤 ③ 的 \(F\) 物理写回负责。

Logits 形式(单层)

\[ s_{ij} = \frac{\langle q_i,\,\hat k_j\rangle}{\sqrt{d}} + \underbrace{\alpha\,\mathbb{1}[\type_j=\mathrm{local}]}_{\text{局部 KV 倾斜项}} + \underbrace{\beta\,\mathbb{1}[j\in\mathcal{S}']}_{\text{假 Sink 抑制项}} \]

\[ o_i = \sum_j \mathrm{softmax}_j(s_{ij})\,\hat v_j \]

  • \(\alpha < 0\)(或对全局 key 等价增权):引导 mass 离开尚未修复的局部 KV;
  • \(\beta < 0\):抑制非首块块首 token 集合 \(\mathcal{S}'\),保留序列最前 \(s\) 个真 sink;
  • 步骤 ③ 写回后 \(\type_j=\mathrm{global}\) 增多 → \(\alpha\) 项自动减弱。

\(F\) 的退化(非独立 regime,而是算子内项为空)

定义纠偏掩码集合(与 §5.4 一致):

\[ \mathcal{M}_{\mathrm{corr}}(L) = \{M_{\mathrm{local}}(j),\;M_{\mathrm{sink}}(j)\}, \;\; M_{\mathrm{local}}(j)=\mathbb{1}[\type_j=\mathrm{local}], \;\; M_{\mathrm{sink}}(j)=\mathbb{1}[j\in\mathcal{S}'] \]

当 \(\mathcal{T}_\tau\) 足够大(如 \(\tau\ge |B|/2\) 且非对称覆盖 \(B_{2:m}\))时:

  • 步骤 ③ 已将相关位置写回为全局KV → \(\type_j=\mathrm{local}\) 的集合为空;
  • 假 Sink 经抑制与写回消除 → \(\mathcal{S}'\) 有效为空;
  • 故 \(\alpha\) 项与 \(\beta\) 项均为空集上的零映射

\[ F(Q,L) \;\equiv\; \mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V \]

自动退化为标准 Transformer 自注意力。 此时计算上仍对后缀做了完整因果 prefill,复用上主要保留 \(B_1\) 缓存收益。

Prefill vs Decode:同一 \(F\);Prefill 在 \(\mathcal{T}_\tau\) 含 KV 写回,Decode 仅追加 query,同质异量。

八、结论与指标

KV 复用率 ≥ 80% TTFT ↓ ≥ 70% 精度损失 < 1% UCM 可复现
主命题: 通过 \(\tau\) 物理修复 + \(F(Q,L)\) 分布纠偏,使 \(F(Q,L)\approx \Attn(Q,K^{\star},V^{\star})\)。
工程主战场: 有限 \(\tau\)(重算 \(\le 5\%\))下,\(F\) 的倾斜与 Sink 抑制仍满足精度指标; 当 \(\tau \ge |B|/2\) 且非对称覆盖 \(B_{2:m}\) 时,\(F\) 退化验证框架正确性。
相对 CacheBlend: 重算预算结构化到块接缝与末块尾部,并以 \(F\) 显式处理假 Sink。

九、参考资料

  1. Chaspark 揭榜页(难题 4 原文入口) 块 KV 复用的交叉注意力修复问题
  2. UCM(Unified Cache Management) github.com/ModelEngine-Group/unified-cache-management
  3. UCM CacheBlend 文档 ucm.readthedocs.io — CacheBlend
  4. CacheBlend(主基线论文) arXiv:2405.16444 · 实现代码
  5. StreamingLLM(真 Sink 保留思想): arXiv:2309.17453
  6. LongBench(长上下文评测): github.com/THUDM/LongBench
  7. RULER(长上下文评测): github.com/hsiehjackson/RULER